import torch
from torch.autograd import Variable
from torch.autograd.function import Function


class ste_round(Function):
    """Rounding op with a straight-through (identity) gradient.

    Forward applies ``torch.round`` to the input (which rounds halves to the
    nearest even integer). Backward hands the incoming gradient back
    unchanged, i.e. the op is treated as having a derivative of 1.
    """

    @staticmethod
    def forward(ctx, input_):
        """Return ``input_`` rounded element-wise.

        Nothing is stored on ``ctx``: the backward pass does not need the
        forward input.
        """
        return torch.round(input_)

    @staticmethod
    def backward(ctx, grad_input_):
        """Pass the upstream gradient through untouched.

        Despite its name, ``grad_input_`` is the gradient arriving from the
        op's output; it is returned as-is as the gradient for the input.
        """
        return grad_input_


def ste_round_func(input):
    """Apply the ``ste_round`` autograd function to ``input``.

    Functional wrapper so callers do not have to spell out
    ``ste_round.apply`` themselves.
    """
    rounded = ste_round.apply(input)
    return rounded
